import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

def sigmoid(x):
    """
    功能:
        计算sigmoid函数
    输入:
        x - 输入数值
    输出:
        y - 输出数值
    """
    return 0.5 * (1.0 + np.tanh(0.5 * x))

# def sigmoid(x):
#     return 1.0 / (1.0 + np.exp(-x))

def get_box_iou_xywh(box1, box2):
    """
    功能: 
        计算边框交并比值
    输入: 
        box1 - 边界框1
        box2 - 边界框2
    输出:
        iou  - 交并比值
    """
    # 计算交集面积
    x1_min = box1[0] - box1[2]/2.0
    y1_min = box1[1] - box1[3]/2.0
    x1_max = box1[0] + box1[2]/2.0
    y1_max = box1[1] + box1[3]/2.0
    
    x2_min = box2[0] - box2[2]/2.0
    y2_min = box2[1] - box2[3]/2.0
    x2_max = box2[0] + box2[2]/2.0
    y2_max = box2[1] + box2[3]/2.0
    
    x_min = np.maximum(x1_min, x2_min)
    y_min = np.maximum(y1_min, y2_min)
    x_max = np.minimum(x1_max, x2_max)
    y_max = np.minimum(y1_max, y2_max)
    
    w = np.maximum(x_max - x_min, 0.0)
    h = np.maximum(y_max - y_min, 0.0)
    
    intersection = w * h # 交集面积
    
    # 计算并集面积
    s1 = box1[2] * box1[3]
    s2 = box2[2] * box2[3]
    
    union = s1 + s2 - intersection # 并集面积
    
    # 计算交并比
    iou = intersection / union
    
    return iou

def get_ignore_label(infer, gtbox, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio):
    """
    功能:
        计算大于阈值的物体标签，设置为-1，不计算损失值
    输入:
        infer            - 特征图像
        gtbox            - 真实边框
        num_classes      - 类别数量
        anchor_size      - 锚框大小
        anchor_mask      - 锚框掩码
        ignore_threshold - 忽略阈值
        downsample_ratio - 下采样率
    输出:
        lbobj            - 物体标签
    """
    # 调整特征形状
    batch_size = infer.shape[0]   # 特征批数
    num_rows   = infer.shape[2]   # 特征行数
    num_cols   = infer.shape[3]   # 特征列数
    num_anchor = len(anchor_mask) # 锚框数量
    
    infer = infer.numpy()
    infer = infer.reshape([-1, num_anchor, 5 + num_classes, num_rows, num_cols]) # 转换特征形状
    
    # 计算预测边框
    pdloc = infer[:, :, 0:4, :, :]        # 获取预测位置
    pdbox = np.zeros(pdloc.shape)         # 预测边框数组
    image_h = num_rows * downsample_ratio # 预测图像高度
    image_w = num_cols * downsample_ratio # 预测图像宽度
    
    for m in range(batch_size): # 遍历图像
        for i in range(num_rows): # 遍历行数
            for j in range(num_cols): # 遍历列数
                for k in range(num_anchor): # 遍历锚框
                    # 获取边框大小
                    anchor_w = anchor_size[2 * anchor_mask[k]]     # 锚框宽度
                    anchor_h = anchor_size[2 * anchor_mask[k] + 1] # 锚框高度
                    
                    # 设置预测边框
                    pdbox[m, k, 0, i, j] = j        # 预测边框cx
                    pdbox[m, k, 1, i, j] = i        # 预测边框cy
                    pdbox[m, k, 2, i, j] = anchor_w # 预测边框pw
                    pdbox[m, k, 3, i, j] = anchor_h # 预测边框ph
                    
    pdbox[:, :, 0, :, :] = (pdbox[:, :, 0, :, :] + sigmoid(pdloc[:, :, 0, :, :])) / num_cols # 预测边框x=cx + dx
    pdbox[:, :, 1, :, :] = (pdbox[:, :, 1, :, :] + sigmoid(pdloc[:, :, 1, :, :])) / num_rows # 预测边框y=cy + dy
    pdbox[:, :, 2, :, :] = (pdbox[:, :, 2, :, :] * np.exp(pdloc[:, :, 2, :, :])) / image_w   # 预测边框w=pw * exp(tw)
    pdbox[:, :, 3, :, :] = (pdbox[:, :, 3, :, :] * np.exp(pdloc[:, :, 3, :, :])) / image_h   # 预测边框h=ph * exp(th)
    pdbox = np.clip(pdbox, 0.0, 1.0) # 限制预测边框范围为[0,1]
    
    # 计算物体标签
    lbobj = np.zeros([batch_size, num_anchor, num_rows, num_cols]) # 物体标签
    for m in range(batch_size): # 遍历图像
        for n in range(len(gtbox[m])): # 遍历真实边框
            # 获取真实边框
            gtbox_x = gtbox[m][n][0] # 真实边框gtx
            gtbox_y = gtbox[m][n][1] # 真实边框gty
            gtbox_w = gtbox[m][n][2] # 真实边框gtw
            gtbox_h = gtbox[m][n][3] # 真实边框gth
            
            # 是否存在物体
            if gtbox_w < 1e-3 or gtbox_h < 1e-3:
                continue
            
            # 获取预测边框
            pdbox_x = pdbox[m, :, 0, :, :] # 预测边框pdx
            pdbox_y = pdbox[m, :, 1, :, :] # 预测边框pdy
            pdbox_w = pdbox[m, :, 2, :, :] # 预测边框pdw
            pdbox_h = pdbox[m, :, 3, :, :] # 预测边框pdh
            
            # 计算交并比值
            box1 = [pdbox_x, pdbox_y, pdbox_w, pdbox_h] # 设置预测边框
            box2 = [gtbox_x, gtbox_y, gtbox_w, gtbox_h] # 设置真实边框
            ious = get_box_iou_xywh(box1, box2)         # 计算交并比值
            
            # 计算物体标签
            index = np.where(ious > ignore_threshold) # 大于阈值标签索引
            lbobj[m][index] = -1                      # 大于阈值物体标签
    
    return lbobj

def get_predict_label(infer, gtbox, gtcls, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio):
    """
    功能:
        计算预测标签
    输入:
        infer            - 特征图像
        gtbox            - 真实边框
        gtcls            - 真实类别
        num_classes      - 类别数量
        anchor_size      - 锚框大小
        anchor_mask      - 锚框掩码
        ignore_threshold - 忽略阈值
        downsample_ratio - 下采样率
    输出:
        lbloc            - 位置标签
        lbobj            - 物体标签
        lbcls            - 分类标签
        wtloc            - 位置权重
    """
    # 设置标签数据
    batch_size = infer.shape[0]   # 特征批数
    num_rows   = infer.shape[2]   # 特征行数
    num_cols   = infer.shape[3]   # 特征列数
    num_anchor = len(anchor_mask) # 锚框数量
    
    lbloc = np.zeros([batch_size, num_anchor, 4, num_rows, num_cols])           # 位置标签
    lbcls = np.zeros([batch_size, num_anchor, num_classes, num_rows, num_cols]) # 类别标签
    wtloc = np.ones([batch_size, num_anchor, num_rows, num_rows])               # 位置权重
    
    # 大于阈值物体
#     lbobj = np.zeros([batch_size, num_anchor, num_rows, num_cols])              # 物体标签
    lbobj = get_ignore_label(infer, gtbox, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio)
    
    # 计算预测标签
    image_h = num_rows * downsample_ratio # 原图高度
    image_w = num_cols * downsample_ratio # 原图宽度
    
    for m in range(batch_size): # 遍历图像
        for n in range(len(gtbox[m])): # 遍历真实边框
            # 获取边框坐标
            gtbox_x = gtbox[m][n][0] # 真实边框gtx
            gtbox_y = gtbox[m][n][1] # 真实边框gty
            gtbox_w = gtbox[m][n][2] # 真实边框gtw
            gtbox_h = gtbox[m][n][3] # 真实边框gth
            
            # 是否存在物体
            if gtbox_w < 1e-3 or gtbox_h < 1e-3:
                continue
            
            # 计算交并比值
            iou_list = [] # 交并比值列表
            for k in range(num_anchor): # 遍历锚框
                anchor_w = anchor_size[2 * anchor_mask[k]]     # 锚框宽度
                anchor_h = anchor_size[2 * anchor_mask[k] + 1] # 锚框高度
                box1 = [0.0, 0.0, anchor_w/float(image_w), anchor_h/float(image_h)] # 设置锚框
                box2 = [0.0, 0.0, float(gtbox_w), float(gtbox_h)]                   # 真实边框
                
                iou = get_box_iou_xywh(box1, box2) # 计算交并比值
                iou_list.append(iou)               # 添加交并比值
            
            # 获取锚框序号
            iou_list = np.array(iou_list)   # 转换数据类型
            iou_sort = np.argsort(iou_list) # 交并比值排序
            k = iou_sort[-1]                # 最大锚框序号
            
            # 设置标签坐标
            i = int(gtbox_y * num_rows) # 特征图行坐标
            j = int(gtbox_x * num_cols) # 特征图列坐标
            
            # 设置位置标签
            lbloc[m, k, 0, i, j] = gtbox_x * num_cols - j # 位置标签dx=sigmoid(tx)=gtx-cx
            lbloc[m, k, 1, i, j] = gtbox_y * num_rows - i # 位置标签dy=sigmoid(ty)=gty-cy
            lbloc[m, k, 2, i, j] = np.log(gtbox_w * image_w / anchor_size[2 * anchor_mask[k]])     # 位置标签tw=log(gtw/pw)
            lbloc[m, k, 3, i, j] = np.log(gtbox_h * image_h / anchor_size[2 * anchor_mask[k] + 1]) # 位置标签th=log(gth/ph)
            lbloc = lbloc.astype('float32')
            
            # 设置物体标签
            lbobj[m, k, i, j] = 1
            lbobj = lbobj.astype('float32')
            
            # 设置类别标签
            c = gtcls[m][n] # 标签位置
            lbcls[m, k, c, i, j] = 1.0
            lbcls = lbcls.astype('float32')
            
            # 设置位置权重
            wtloc[m, k, i, j] = 2.0 - gtbox_w * gtbox_h # 调节不同尺寸锚框对损失函数的贡献，作为加权系数和位置损失函数相乘
            wtloc = wtloc.astype('float32')
            
    return lbloc, lbobj, lbcls, wtloc

def get_loss(infer, gtbox, gtcls, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio):
    """
    功能:
        计算每张图像的损失总和
    输入:
        infer            - 特征图像
        gtbox            - 真实边框
        gtcls            - 真实类别
        num_classes      - 类别数量
        anchor_size      - 锚框大小
        anchor_mask      - 锚框掩码
        ignore_threshold - 忽略阈值
        downsample_ratio - 下采样率
    输出:
        sum_loss         - 损失总和
    """
    # 计算预测标签
    lbloc, lbobj, lbcls, wtloc = get_predict_label(infer, gtbox, gtcls, 
                                                   num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio)
    
    # 转换标签格式
    lbloc = to_variable(lbloc)
    lbobj = to_variable(lbobj)
    lbcls = to_variable(lbcls)
    wtloc = to_variable(wtloc)
    
    lbloc.stop_gradient=True # 停止梯度计算
    lbobj.stop_gradient=True # 停止梯度计算
    lbcls.stop_gradient=True # 停止梯度计算
    wtloc.stop_gradient=True # 停止梯度计算
    
    # 转换特征格式
    infer = fluid.layers.reshape(infer, [-1, len(anchor_mask), 5 + num_classes, infer.shape[2], infer.shape[3]])
    
    # 正样本值位置
    ploss = lbobj > 0                           # 正样本值位置
    ploss = fluid.layers.cast(ploss, 'float32') # 转换数据格式
    ploss.stop_gradient=True                    # 停止梯度计算
    
    # 计算位置损失
    pdloc_dx = infer[:, :, 0, :, :] # 预测位置dx=sigmoid(tx)
    pdloc_dy = infer[:, :, 1, :, :] # 预测位置dy=sigmoid(ty)
    pdloc_tw = infer[:, :, 2, :, :] # 预测位置tw
    pdloc_th = infer[:, :, 3, :, :] # 预测位置th
    
    lbloc_dx = lbloc[:, :, 0, :, :] # 标签位置dx=sigmoid(tx)
    lbloc_dy = lbloc[:, :, 1, :, :] # 标签位置dy=sigmoid(ty)
    lbloc_tw = lbloc[:, :, 2, :, :] # 标签位置tw
    lbloc_th = lbloc[:, :, 3, :, :] # 标签位置th
    
    loss_loc_dx = fluid.layers.sigmoid_cross_entropy_with_logits(pdloc_dx, lbloc_dx) # 计算位置损失dx
    loss_loc_dy = fluid.layers.sigmoid_cross_entropy_with_logits(pdloc_dy, lbloc_dy) # 计算位置损失dy
    loss_loc_tw = fluid.layers.abs(pdloc_tw - lbloc_tw)                              # 计算位置损失tw
    loss_loc_th = fluid.layers.abs(pdloc_th - lbloc_th)                              # 计算位置损失th
    
    loss_loc = loss_loc_dx + loss_loc_dy + loss_loc_tw + loss_loc_th # 计算总的位置损失
    loss_loc = loss_loc * wtloc                                      # 带权重的位置损失
    loss_loc = loss_loc * ploss                                      # 正样本的位置损失
    
    # 计算物体损失
    pdobj = infer[:, :, 4, :, :]                                                             # 物体预测数值
    loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(pdobj, lbobj, ignore_index=-1) # 忽略标签为-1梯度
    
    # 计算类别损失
    pdcls = infer[:, :, 5:5+num_classes, :, :]                              # 类别预测数值
    loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(pdcls, lbcls) # 计算类别损失
    loss_cls = fluid.layers.reduce_sum(loss_cls, dim=2)                     # 对通道维损失求和
    loss_cls = loss_cls * ploss                                             # 正样本的类别损失
    
    # 计算平均损失
    sum_loss = loss_loc + loss_obj + loss_cls                   # 计算损失总和
    sum_loss = fluid.layers.reduce_sum(sum_loss, dim=[1, 2, 3]) # 每张图像损失
    
    return sum_loss

def get_sum_loss(infer, gtbox, gtcls, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio):
    """
    功能:
        计算三个输出的损失总和
    输入:
        infer            - 特征列表
        gtbox            - 真实边框
        gtcls            - 真实类别
        num_classes      - 类别数量
        anchor_size      - 锚框大小
        anchor_mask      - 锚框掩码
        ignore_threshold - 样本阈值
        downsample_ratio - 下采样率
    输出:
        sum_loss         - 平均损失总和
    """
    # 计算平均损失
    loss_list = [] # 平均损失列表
    for i in range(len(infer)):
        # 计算平均损失
        loss = get_loss(infer[i], gtbox, gtcls, num_classes, anchor_size, anchor_mask[i], ignore_threshold, downsample_ratio)
        loss_list.append(fluid.layers.reduce_mean(loss)) # 添加损失列表
        
        # 减小下采样率
        downsample_ratio //= 2 # 减小下采样率
    
    # 计算损失总和
    sum_loss = sum(loss_list)
    
    return sum_loss

# def get_sum_loss(infer, gtbox, gtcls, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio):
#     # 计算平均损失
#     loss_list = [] # 平均损失列表
#     gtbox = to_variable(gtbox)
#     gtcls = to_variable(gtcls)
    
#     for i in range(len(infer)):
#         # 计算平均损失
#         loss = fluid.layers.yolov3_loss(
#             x=infer[i],
#             gt_box=gtbox,
#             gt_label=gtcls,
#             class_num=num_classes,
#             anchors=anchor_size,
#             anchor_mask=anchor_mask[i],
#             ignore_thresh=ignore_threshold,
#             downsample_ratio=downsample_ratio,
#             use_label_smooth=False)
#         loss_list.append(fluid.layers.reduce_mean(loss)) # 添加损失列表
        
#         # 减小下采样率
#         downsample_ratio //= 2 # 减小下采样率
    
#     # 计算损失总和
#     sum_loss = sum(loss_list)
    
#     return sum_loss